{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adaptive RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "var userHomeDir = System.getProperty(\"user.home\");\n",
    "var localRespoUrl = \"file://\" + userHomeDir + \"/.m2/repository/\";\n",
    "var langchain4jVersion = \"1.0.1\";\n",
    "var langchain4jbeta = \"1.0.1-beta6\";\n",
    "var langgraph4jVersion = \"1.6-SNAPSHOT\";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%dependency /add-repo local \\{localRespoUrl} release|never snapshot|always\n",
    "// %dependency /list-repos\n",
    "%dependency /add org.slf4j:slf4j-jdk14:2.0.9\n",
    "%dependency /add org.bsc.langgraph4j:langgraph4j-core:\\{langgraph4jVersion}\n",
    "%dependency /add org.bsc.langgraph4j:langgraph4j-langchain4j:\\{langgraph4jVersion}\n",
    "%dependency /add dev.langchain4j:langchain4j:\\{langchain4jVersion}\n",
    "%dependency /add dev.langchain4j:langchain4j-open-ai:\\{langchain4jVersion}\n",
    "\n",
    "%dependency /resolve"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initialize Logger**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "try( var file = new java.io.FileInputStream(\"./logging.properties\")) {\n",
    "    java.util.logging.LogManager.getLogManager().readConfiguration( file );\n",
    "}\n",
    "\n",
    "var log = org.slf4j.LoggerFactory.getLogger(\"AdaptiveRag\");\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Issue [#32](https://github.com/bsorrentino/langgraph4j/issues/32)\n",
    "\n",
    "Issue concerns a problem on `AdaptiveRag` implementation referred to `AnswerGrader` task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dev.langchain4j.model.chat.ChatModel;\n",
    "import dev.langchain4j.model.input.Prompt;\n",
    "import dev.langchain4j.model.input.structured.StructuredPrompt;\n",
    "import dev.langchain4j.model.input.structured.StructuredPromptProcessor;\n",
    "import dev.langchain4j.model.openai.OpenAiChatModel;\n",
    "import dev.langchain4j.model.output.structured.Description;\n",
    "import dev.langchain4j.service.AiServices;\n",
    "import dev.langchain4j.service.SystemMessage;\n",
    "import java.time.Duration;\n",
    "import java.util.function.Function;\n",
    "\n",
    "\n",
    "public class AnswerGrader implements Function<AnswerGrader.Arguments,AnswerGrader.Score> {\n",
    "\n",
    "    static final String MODELS[] =  { \"gpt-3.5-turbo-0125\", \"gpt-4o-mini\" };\n",
    "\n",
    "    /**\n",
    "     * Binary score to assess answer addresses question.\n",
    "     */\n",
    "    public static class Score {\n",
    "\n",
    "        @Description(\"Answer addresses the question, 'yes' or 'no'\")\n",
    "        public String binaryScore;\n",
    "\n",
    "        @Override\n",
    "        public String toString() {\n",
    "            return \"Score: \" + binaryScore;\n",
    "        }\n",
    "    }\n",
    "\n",
    "    @StructuredPrompt(\"\"\"\n",
    "User question: \n",
    "\n",
    "{{question}}\n",
    "\n",
    "LLM generation: \n",
    "\n",
    "{{generation}}\n",
    "\"\"\")\n",
    "    record Arguments(String question, String generation) {\n",
    "    }\n",
    "\n",
    "    interface Service {\n",
    "\n",
    "        @SystemMessage(\"\"\"\n",
    "You are a grader assessing whether an answer addresses and/or resolves a question. \n",
    "\n",
    "Give a binary score 'yes' or 'no'. Yes, means that the answer resolves the question otherwise return 'no'\n",
    "        \"\"\")\n",
    "        Score invoke(String userMessage);\n",
    "    }\n",
    "\n",
    "    String openApiKey;\n",
    "\n",
    "    @Override\n",
    "    public Score apply(Arguments args) {\n",
    "        var chatLanguageModel = OpenAiChatModel.builder()\n",
    "                .apiKey( System.getenv(\"OPENAI_API_KEY\")  )\n",
    "                .modelName( MODELS[1] )\n",
    "                .timeout(Duration.ofMinutes(2))\n",
    "                .logRequests(true)\n",
    "                .logResponses(true)\n",
    "                .maxRetries(2)\n",
    "                .temperature(0.0)\n",
    "                .maxTokens(2000)\n",
    "                .build();\n",
    "\n",
    "\n",
    "        Service service = AiServices.create(Service.class, chatLanguageModel);\n",
    "\n",
    "        Prompt prompt = StructuredPromptProcessor.toPrompt(args);\n",
    "\n",
    "        log.trace( \"prompt: {}\", prompt.text() );\n",
    "        \n",
    "        return service.invoke(prompt.text());\n",
    "    }\n",
    "\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "prompt: User question:\n",
      "\n",
      "What are the four operations ? \n",
      "\n",
      "LLM generation:\n",
      "\n",
      "LLM means Large Language Model\n",
      " \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Score: no"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "var grader = new AnswerGrader();\n",
    "\n",
    "var args = new AnswerGrader.Arguments( \"What are the four operations ? \", \"LLM means Large Language Model\" );\n",
    "grader.apply( args );\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "prompt: User question:\n",
      "\n",
      "What are the four operations\n",
      "\n",
      "LLM generation:\n",
      "\n",
      "There are four basic operations: addition, subtraction, multiplication, and division.\n",
      " \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Score: yes"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var args = new AnswerGrader.Arguments( \"What are the four operations\", \"There are four basic operations: addition, subtraction, multiplication, and division.\" );   \n",
    "grader.apply( args );\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "prompt: User question:\n",
      "\n",
      "What player at the Bears expected to draft first in the 2024 NFL draft?\n",
      "\n",
      "LLM generation:\n",
      "\n",
      "The Bears selected USC quarterback Caleb Williams with the No. 1 pick in the 2024 NFL Draft.\n",
      " \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Score: yes"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "var args = new AnswerGrader.Arguments( \"What player at the Bears expected to draft first in the 2024 NFL draft?\", \"The Bears selected USC quarterback Caleb Williams with the No. 1 pick in the 2024 NFL Draft.\" );   \n",
    "grader.apply( args );\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java (rjk 2.2.0)",
   "language": "java",
   "name": "rapaio-jupyter-kernel"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "java",
   "nbconvert_exporter": "script",
   "pygments_lexer": "java",
   "version": "22.0.2+9-70"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
